from typing import Optional, List, Union

from langchain.callbacks import CallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
        """
        errors: ProviderTokenNotInitError
        """
        cls.validate_query_tokens(app.tenant_id, app_model_config, query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation
            )

            # for an existing conversation, reuse the inputs stored on it
            inputs = conversation.inputs

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        # build the main chain, including the agent if one is configured
        main_chain = MainChainBuilder.to_langchain_components(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
        )

        # the builder may return None (e.g. no agent configured); default to an empty chain output
        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            # the task was stopped (e.g. by the user); exit quietly
            return

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
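        """Build the final prompt, rescale max_tokens to fit the context window, and run the model."""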
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            memory=memory
        )

        # attach callbacks so generation events flow to the conversation task (and stdout)
        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)

        cls.recalc_llm_max_tokens(
            final_llm=final_llm,
            prompt=prompt,
            mode=mode
        )

        # generate() accepts a batch of prompts; pass a single-element batch
        response = final_llm.generate([prompt])

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Union[str, List[BaseMessage]]:
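        """
        Build the final prompt for the model.

        In 'completion' mode, return a raw prompt string (wrapped in a single
        HumanMessage for chat models); otherwise, return a message list made of an
        optional system message (pre-prompt and/or [CONTEXT]), trimmed chat
        history, and the user's query.
        """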
        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
        if mode == 'completion':
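            # assemble the template: optional [CONTEXT] block, optional pre-prompt, then the user query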
            prompt_template = OutLinePromptTemplate.from_template(
                template=("Use the following pieces of [CONTEXT] to answer the question at the end. "
                          "If you don't know the answer, "
                          "just say that you don't know, don't try to make up an answer. \n"
                          "```\n"
                          "[CONTEXT]\n"
                          "{context}\n"
                          "```\n" if chain_output else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{query}\n"
            )

            if chain_output:
                # expose the chain output to the template as {context}
                inputs['context'] = chain_output

            # pass through only the variables the template actually declares
            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)]
            else:
                return prompt_content
        else:
            messages: List[BaseMessage] = []

            system_message = None
            if pre_prompt:
                # append pre prompt as system message
                system_message = PromptBuilder.to_system_message(pre_prompt, inputs)

            if chain_output:
                # prepend the chain output as [CONTEXT] in the system message; currently a simple "stuff" prompt
                context_message = PromptBuilder.to_system_message(
                    """Use the following pieces of [CONTEXT] to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
```
[CONTEXT]
{context}
```""",
                    {'context': chain_output}
                )

                if not system_message:
                    system_message = context_message
                else:
                    system_message.content = context_message.content + "\n\n" + system_message.content

            if system_message:
                messages.append(system_message)

            human_inputs = {
                "query": query
            }

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content="{query}",
                inputs=human_inputs
            )

            if memory:
                # append chat histories
                tmp_messages = messages.copy() + [human_message]
                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
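                # history budget = model context window - reserved completion tokens - tokens already in the prompt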
                rest_tokens = llm_constant.max_context_token_length[
                                  memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
                messages += history_messages

            messages.append(human_message)

            return messages

    @classmethod
    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
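        """Pair the LLM callback handler with a streaming or plain stdout handler."""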
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]

        return CallbackManager(callback_handlers)

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> \
            List[BaseMessage]:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
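        """
        Build a read-only, token-limited memory over an existing conversation.

        The defaults below can be overridden via kwargs, e.g. a (hypothetical)
        caller passing max_token_limit=1024.
        """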
        # this LLM instance is used only for token counting in memory
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # build a read-only memory view over the conversation history
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
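        """Reject queries that cannot fit into the model's context window alongside max_tokens."""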
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
        max_tokens = llm.max_tokens

        # the query must fit in the context window alongside the reserved completion budget
        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
            raise LLMBadRequestError("Query is too long")

    @classmethod
    def recalc_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # shrink max_tokens when prompt_tokens + max_tokens would exceed the model's context window
        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
        max_tokens = final_llm.max_tokens

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            # completion models count tokens over the raw prompt string
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            # chat models count tokens over the message list
            prompt_tokens = final_llm.get_messages_tokens(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            # leave at least a minimal completion budget of 16 tokens
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
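        """Regenerate an answer in the style of an existing message, reusing its original prompt."""
        # note: this always uses gpt-3.5-turbo, regardless of the app's configured model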
        llm: Union[StreamableOpenAI, StreamableChatOpenAI] = LLMBuilder.to_llm(
            tenant_id=app.tenant_id,
            model_name='gpt-3.5-turbo',
            streaming=streaming
        )

        # rebuild the prompt originally used for the source message
        original_prompt = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            # chat models expect a list of messages rather than a raw prompt string
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=bool(message.override_model_configs),
            streaming=streaming
        )

        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)

        cls.recalc_llm_max_tokens(
            final_llm=llm,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])
